In [1]:
import audio
import hparams
from IPython.display import Audio, display, HTML
import librosa
import librosa.display
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle

# Work from the wavenet_vocoder checkout so the relative paths used later in
# this notebook resolve. The directory can be overridden by setting the
# WAVENET_VOCODER_DIR environment variable; the default is the original
# hard-coded location.
os.chdir(os.environ.get('WAVENET_VOCODER_DIR',
                        '/home/peter/wavenet-features-private/wavenet_vocoder'))
# Sample rate from the project's hparams; used below for audio playback and
# for converting sample/frame counts to seconds.
RATE = hparams.hparams.sample_rate
In [2]:
# Delete every *.pickle file in the current working directory, then copy the
# *.pickle files from the remote wavenet_vocoder directory into it via scp.
# NOTE(review): the remote host, user and key path are hard-coded here.
!rm -f *.pickle
!scp -i ~/.ssh/no_passphrase_rsa peterpaullake@35.202.127.142:/home/peterpaullake/wavenet-features-private/wavenet_vocoder/*.pickle .
mel-0-layer-1-channel-180-mode-skip.pickle    100%   79KB 416.2KB/s   00:00    
mel-0-layer-9-channel-124-mode-skip.pickle    100%   79KB 829.7KB/s   00:00    
mel-3-layer-14-channel-18-mode-skip.pickle    100%   72KB   1.4MB/s   00:00    
mel-4-layer-12-channel-144-mode-skip.pickle   100%   66KB   1.3MB/s   00:00    
mel-6-layer-2-channel-311-mode-res.pickle     100%   78KB   1.5MB/s   00:00    
mel-7-layer-1-channel-180-mode-skip.pickle    100%  108KB   2.1MB/s   00:00    
mel-7-layer-12-channel-144-mode-skip.pickle   100%  108KB   2.1MB/s   00:00    
mel-7-layer-14-channel-18-mode-skip.pickle    100%  108KB   2.2MB/s   00:00    
mel-7-layer-2-channel-311-mode-res.pickle     100%  108KB   2.1MB/s   00:00    
mel-7-layer-9-channel-124-mode-skip.pickle    100%  108KB   2.2MB/s   00:00    
mel-8-layer-1-channel-180-mode-skip.pickle    100%  108KB   2.1MB/s   00:00    
mel-8-layer-12-channel-144-mode-skip.pickle   100%  108KB   2.2MB/s   00:00    
mel-8-layer-2-channel-311-mode-res.pickle     100%  108KB   2.1MB/s   00:00    
mel-8-layer-9-channel-124-mode-skip.pickle    100%  108KB   2.1MB/s   00:00    
mel-9-layer-1-channel-180-mode-skip.pickle    100%  108KB   2.1MB/s   00:00    
mel-9-layer-12-channel-144-mode-skip.pickle   100%  108KB   2.1MB/s   00:00    
mel-9-layer-2-channel-311-mode-res.pickle     100%  108KB   2.1MB/s   00:00    
mel-9-layer-9-channel-124-mode-skip.pickle    100%  108KB   2.1MB/s   00:00    
In [5]:
# Figure resolution (dots per inch) passed to matplotlib when creating figures.
DPI = 100

def config_time_axis(ax, num, duration):
    """Label the x axis of ``ax`` in seconds.

    Tick positions are treated as indices: a tick at ``value`` is rendered as
    ``value / num * duration`` with two decimals.

    Args:
        ax: Axes whose x axis is configured.
        num: Number of index steps that correspond to ``duration``.
        duration: Length in seconds spanned by ``num`` steps.
    """
    def seconds_label(tick_value, _tick_pos):
        seconds = tick_value / num * duration
        return '%.2f' % seconds

    ax.set_xlabel('Time (seconds)')
    ax.xaxis.set_major_formatter(plt.FuncFormatter(seconds_label))
    
def config_freq_axis(ax, num_bins):
    """Label the y axis of ``ax`` in Hz, treating tick positions as mel-bin indices.

    Bin index 0 maps to ``fmin`` and index ``num_bins`` maps to ``fmax``, with
    bins spaced evenly on the mel scale ``1127 * ln(1 + f / 700)``.

    Args:
        ax: Axes whose y axis is configured.
        num_bins: Number of bin indices spanning ``fmin``..``fmax``.

    NOTE(review): this uses the formula above regardless of how the plotted
    features were actually produced — confirm it matches the mel filterbank
    used upstream.
    """
    fmin, fmax = 125, 7600  # from Tacotron2 hparams.py

    def hz_to_mel(freq):
        return 1127 * np.log(1 + freq / 700)

    def mel_to_hz(mel):
        return 700 * (np.exp(mel / 1127) - 1)

    lower_mel = hz_to_mel(fmin)
    upper_mel = hz_to_mel(fmax)

    def hz_label(tick_value, _tick_pos):
        # Width of one bin on the mel scale.
        mel_per_bin = (upper_mel - lower_mel) / num_bins
        return '%.2f' % mel_to_hz(lower_mel + tick_value * mel_per_bin)

    ax.set_ylabel('Frequency (Hz)')
    ax.yaxis.set_major_formatter(plt.FuncFormatter(hz_label))
    
def plot_mel(fig, ax, title, mel):
    """Draw ``mel`` as an image on ``ax`` with time (x) and frequency (y) axes.

    ``imshow`` places the rows of ``mel`` along the y axis and its columns
    along the x axis, so ``mel`` is read as (mel bins, frames): the frequency
    axis is scaled by ``mel.shape[0]`` and the time axis by ``mel.shape[1]``.

    Args:
        fig: Figure that receives the colorbar.
        ax: Axes to draw on.
        title: Title for the axes.
        mel: 2-D array to display.
    """
    num_bins, num_frames = mel.shape[0], mel.shape[1]
    # Each frame is taken to advance audio.get_hop_size() samples at RATE.
    duration = num_frames * audio.get_hop_size() / RATE
    config_time_axis(ax, num_frames, duration)
    # The y axis indexes rows, so it must be scaled by the row count; scaling
    # it by the column count mislabels every frequency tick.
    config_freq_axis(ax, num_bins)

    ax.set_title(title)
    im = ax.imshow(mel, origin='lower', cmap='coolwarm', aspect='auto')
    fig.colorbar(im, ax=ax)

def display_lc_mel(mel, text=None):
    """Show ``mel`` in its own 2x2-inch figure.

    Args:
        mel: 2-D array handed to ``plot_mel``.
        text: Optional title; an empty title is used when it is ``None``.
    """
    figure, axis = plt.subplots(figsize=(2, 2))
    if text is None:
        text = ''
    plot_mel(figure, axis, text, mel)
    plt.show()
    
def display_wave(wave, title):
    """Plot a waveform, a zoomed-in excerpt around its midpoint, and an audio player.

    The top panel shows the whole waveform with a dotted red rectangle marking
    the excerpt; the bottom panel shows the excerpt itself, which covers up to
    ``RATE // 100`` samples on either side of the midpoint.

    Args:
        wave: Array of samples; must support slicing and ``.min()``/``.max()``.
        title: Figure-level title.
    """
    fig, axes = plt.subplots(2, figsize=(10, 4), dpi=DPI)

    half_width = RATE // 100
    mid = len(wave) // 2
    # Clamp the window to the waveform. Without this, a clip shorter than
    # 2 * half_width gives a negative start index, which slicing interprets
    # as an offset from the end and so selects the wrong samples.
    start = max(mid - half_width, 0)
    end = min(mid + half_width, len(wave))
    excerpt = wave[start:end]

    axes[0].set_title('Entire waveform')
    axes[0].set_ylim(-1, 1)
    axes[0].plot(wave)
    config_time_axis(axes[0], len(wave), len(wave) / RATE)

    # Outline the excerpt on the full waveform.
    ymin, ymax = excerpt.min(), excerpt.max()
    rec = matplotlib.patches.Rectangle((start, ymin),
                                       end - start,
                                       ymax - ymin,
                                       linewidth=2,
                                       linestyle=':',
                                       edgecolor='r',
                                       facecolor='none',
                                       zorder=3)
    axes[0].add_patch(rec)

    axes[1].set_title('Zoomed in')
    axes[1].plot(excerpt)
    # Time labels on the zoomed panel are relative to the start of the
    # excerpt, not to the start of the whole waveform.
    config_time_axis(axes[1], end - start, (end - start) / RATE)

    fig.suptitle(title, fontsize=14, y=1.01)
    plt.subplots_adjust(hspace=0.7)
    plt.show()
    display(Audio(wave, rate=RATE))
    
def center_and_scale_wave(wave):
    """Return a mean-centred, rescaled copy of ``wave``.

    The copy has its mean subtracted, is multiplied by 100, and is then
    divided by its own 0.1% quantile (skipped when that quantile is zero).
    The input array is left untouched.

    Args:
        wave: Array-like of samples.

    Returns:
        A new floating-point ndarray with the same shape as ``wave``.

    NOTE(review): the divisor is a low quantile, so whenever it is negative
    the result is sign-inverted relative to the input. This matches the
    previous behaviour; confirm it is intended.
    """
    # np.array makes a real copy. The previous `wave[:]` is only a view for
    # ndarrays, so the in-place operations below silently modified the
    # caller's array.
    scaled = np.array(wave)
    if not np.issubdtype(scaled.dtype, np.floating):
        # In-place float arithmetic is not possible on integer arrays.
        scaled = scaled.astype(float)
    scaled -= np.mean(scaled)
    scaled *= 100
    low_quantile = np.quantile(scaled, 1e-3)
    if low_quantile != 0:
        scaled /= low_quantile
    return scaled

def display_spec(wave):
    """Show a log-frequency spectrogram (in dB relative to the peak) of ``wave``.

    Args:
        wave: Audio samples, passed to ``librosa.stft``.
    """
    plt.figure(figsize=(10, 10), dpi=DPI)
    # The STFT is complex; take magnitudes explicitly before converting to dB.
    magnitude = np.abs(librosa.stft(wave))
    D = librosa.amplitude_to_db(magnitude, ref=np.max)
    # Only the first cell of a 4x2 grid is used, which keeps the plot small.
    plt.subplot(4, 2, 1)
    # Pass the real sample rate so the frequency axis is labelled for RATE
    # rather than for librosa's default rate.
    librosa.display.specshow(D, sr=RATE, y_axis='log')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Log-frequency power spectrogram', y=1.01)
    plt.show()

# Local-conditioning mel spectrogram files, indexed by the 'mel_id' stored
# with each experiment.
mel_paths = ['../Tacotron-2/tacotron_output/eval/speech-mel-00001.npy',
             '../Tacotron-2/tacotron_output/eval/speech-mel-00002.npy',
             '../Tacotron-2/tacotron_output/eval/speech-mel-00003.npy',
             '../Tacotron-2/tacotron_output/eval/speech-mel-00004.npy',
             '../Tacotron-2/tacotron_output/eval/speech-mel-00005.npy',
             '../Tacotron-2/tacotron_output/eval/speech-mel-00006.npy',
             '../Tacotron-2/tacotron_output/eval/speech-mel-00007.npy',
             '../Tacotron-2/tacotron_output/eval/diagonal.npy',
             '../Tacotron-2/tacotron_output/eval/straight.npy',
             '../Tacotron-2/tacotron_output/eval/zeros.npy']

# Display names: one quoted line of text_list.txt per entry, followed by three
# fixed names.
# NOTE(review): names line up with mel_paths only if text_list.txt has exactly
# seven lines — confirm.
with open('../text_list.txt') as text_file:
    # rstrip('\n') rather than line[:-1], so a final line without a trailing
    # newline does not lose its last character.
    mel_names = ['\'' + line.rstrip('\n') + '\'' for line in text_file]
mel_names += ['Chirp', 'Constant frequency', 'Zeros']

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('../sounds.pickle', 'rb') as sounds_file:
    mel_waves = pickle.load(sounds_file)['data']

def display_experiment(wave, props):
    """Render one experiment: its conditioning mel, reference audio, and result wave.

    Args:
        wave: Waveform to plot, play, and show as a spectrogram.
        props: Mapping with a ``'mel_id'`` entry that indexes ``mel_paths``,
            ``mel_names`` and (when in range) ``mel_waves``.
    """
    mel_id = props['mel_id']
    mel = np.load(mel_paths[mel_id])
    # Entries 7-9 are shown at half resolution along their first axis.
    if mel_id in range(7, 10):
        mel = mel[::2]
    # Map values from [0, 4] to [0, 1] (np.interp clamps outside that range),
    # then swap the two axes for plotting.
    mel = np.swapaxes(np.interp(mel, (0, 4), (0, 1)), 0, 1)
    display_lc_mel(mel, 'Local conditioning features:\n' + mel_names[mel_id])

    # Not every mel has a matching reference recording.
    if mel_id < len(mel_waves):
        display(Audio(mel_waves[mel_id], rate=RATE))

    display_wave(wave, '')
    display_spec(wave)
    display(HTML('<hr style="border:1px solid black;">'))

# Result pickles in the current working directory. os.listdir() returns names
# in arbitrary order, so sort them to make the display order reproducible.
pickle_paths = sorted(path for path in os.listdir() if path.endswith('.pickle'))

def display_unit(layer_id, channel_id, mode):
    """Display every stored experiment recorded for one (layer, channel, mode) unit.

    Prints a heading, then loads each file in ``pickle_paths`` and, for those
    whose stored properties match the requested unit, shows the experiment
    after squashing the raw wave with ``tanh`` and centring/scaling it.

    Args:
        layer_id: Layer index to match against ``props['layer_id']``.
        channel_id: Channel index to match against ``props['channel_id']``.
        mode: ``'res'`` or ``'skip'``, matched against ``props['mode']``.
    """
    mode_name = 'residual' if mode == 'res' else 'skip'
    title = 'Maximizing activation of layer %d channel %d %s output' % (layer_id,
                                                                        channel_id,
                                                                        mode_name)
    display(HTML('<h1>%s</h1>' % title))
    wanted = (layer_id, channel_id, mode)
    for path in pickle_paths:
        # Close each file promptly instead of leaking one handle per pickle.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as pickle_file:
            raw_wave, props = pickle.load(pickle_file)['data']
        if (props['layer_id'], props['channel_id'], props['mode']) != wanted:
            continue
        wave = np.tanh(raw_wave)
        wave = center_and_scale_wave(wave)
        display_experiment(wave, props)

# Units to render, as (layer, channel, mode), in display order.
for unit in [(12, 144, 'skip'),
             (2, 311, 'res'),
             (1, 180, 'skip'),
             (9, 124, 'skip'),
             (14, 18, 'skip')]:
    display_unit(*unit)

Maximizing activation of layer 12 channel 144 skip output

Your browser does not support the audio element.

Your browser does not support the audio element.

Your browser does not support the audio element.

Your browser does not support the audio element.
Your browser does not support the audio element.

Maximizing activation of layer 2 channel 311 residual output

Your browser does not support the audio element.

Your browser does not support the audio element.
Your browser does not support the audio element.

Your browser does not support the audio element.

Your browser does not support the audio element.

Maximizing activation of layer 1 channel 180 skip output

Your browser does not support the audio element.

Your browser does not support the audio element.

Your browser does not support the audio element.

Your browser does not support the audio element.
Your browser does not support the audio element.

Maximizing activation of layer 9 channel 124 skip output

Your browser does not support the audio element.
Your browser does not support the audio element.

Your browser does not support the audio element.

Your browser does not support the audio element.

Your browser does not support the audio element.

Maximizing activation of layer 14 channel 18 skip output

Your browser does not support the audio element.

Your browser does not support the audio element.
Your browser does not support the audio element.

In [4]:
# Load the stored per-unit statistics. Each entry of `unit_vars` is indexed by
# 'res' and 'skip'.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('../unit-vars.pickle', 'rb') as unit_vars_file:
    unit_vars = pickle.load(unit_vars_file)['data']

# Average the stacked entries, separately for the 'res' and 'skip' outputs.
mean_unit_vars = {
    key: np.stack([entry[key] for entry in unit_vars]).mean(axis=0)
    for key in ['res', 'skip']
}

assert len(mean_unit_vars['res']) == len(mean_unit_vars['skip'])

# For each layer, report the channel with the largest averaged value in each
# output and that value.
for layer_id in range(len(mean_unit_vars['res'])):
    res_vars = mean_unit_vars['res'][layer_id]
    skip_vars = mean_unit_vars['skip'][layer_id]
    res_id = np.argmax(res_vars)
    skip_id = np.argmax(skip_vars)
    s = 'Layer %d highest variance res/skip channels are %d, %d with variances %f and %f'
    t = (layer_id, res_id, skip_id, res_vars[res_id], skip_vars[skip_id])
    print(s % t)
Layer 0 highest variance res/skip channels are 440, 245 with variances 6.221451 and 113.685120
Layer 1 highest variance res/skip channels are 311, 180 with variances 31.539085 and 262.544800
Layer 2 highest variance res/skip channels are 311, 167 with variances 48.508530 and 176.072357
Layer 3 highest variance res/skip channels are 311, 167 with variances 34.334328 and 81.397980
Layer 4 highest variance res/skip channels are 311, 167 with variances 22.511044 and 102.421120
Layer 5 highest variance res/skip channels are 123, 167 with variances 15.746262 and 62.437580
Layer 6 highest variance res/skip channels are 123, 16 with variances 9.162777 and 23.227379
Layer 7 highest variance res/skip channels are 123, 144 with variances 5.131449 and 32.383545
Layer 8 highest variance res/skip channels are 123, 144 with variances 3.626981 and 39.800583
Layer 9 highest variance res/skip channels are 419, 124 with variances 4.968760 and 45.617863
Layer 10 highest variance res/skip channels are 419, 144 with variances 4.603286 and 55.304104
Layer 11 highest variance res/skip channels are 419, 144 with variances 2.301641 and 28.787510
Layer 12 highest variance res/skip channels are 21, 144 with variances 0.111429 and 13.627382
Layer 13 highest variance res/skip channels are 21, 147 with variances 0.101510 and 0.807069
Layer 14 highest variance res/skip channels are 21, 18 with variances 0.067813 and 0.515050
Layer 15 highest variance res/skip channels are 316, 190 with variances 0.039555 and 0.165396
Layer 16 highest variance res/skip channels are 255, 132 with variances 0.044264 and 0.153228
Layer 17 highest variance res/skip channels are 75, 132 with variances 0.041973 and 0.378613
Layer 18 highest variance res/skip channels are 230, 132 with variances 0.069366 and 0.474421
Layer 19 highest variance res/skip channels are 489, 132 with variances 0.079235 and 0.238216
Layer 20 highest variance res/skip channels are 243, 202 with variances 0.090310 and 0.051206
Layer 21 highest variance res/skip channels are 492, 190 with variances 0.095038 and 0.045689
Layer 22 highest variance res/skip channels are 492, 112 with variances 0.077813 and 0.036163
Layer 23 highest variance res/skip channels are 492, 67 with variances 0.040618 and 0.017965
In [ ]: